#include "statsxx/machine_learning/NeuralNet.hpp"

// STL
#include <cstddef>                    // std::size_t
#include <fstream>                    // std::ofstream
#include <tuple>                      // std::tie
#include <utility>                    // std::pair
#include <vector>                     // std::vector

// jScience
#include "jutility.hpp"               // get_n_rand_unique_elements()
#include "jScience/math/utility.hpp"  // ngroups()
#include "jScience/stats/data.hpp"    // shuffle_data()

// stats++
#include "statsxx/dataset.hpp"     // DataSet
//#include "statsxx/optimization/line_search_methods/LineSearchMethod.hpp"
#include "statsxx/machine_learning/NeuralNet/utility.hpp" // fit_validation_error()

//
// DESC:
//
// INPUT:
//          DataSet ds_tr  :: Training dataset
//          DataSet ds_val :: Validation dataset
//
// TODO: This subroutine needs updating.
//
// TODO: ... Either the learning rate or optimal value needs to be set (it is hacked at 0.0001).
//
// TODO: ... Also, options to set PR or FR need to be possible.
//
inline void NEURAL_NET::CG(
                           const int      max_epochs,
                           // -----
                           const bool     early_stopping,
                           // -----
                           const DataSet &ds_tr,
                           const DataSet &ds_val
                           )
{
//    static unsigned int max_epochs = 1000;

    //*********************************************************

    std::vector<double> w;
    for( auto &link : m_links )
    {
        w.push_back(link.weight);
    }

    // CONVG/STORAGE
    std::vector<double> Etr, Etr_stddev, Ev, Ev_stddev;
    std::vector<std::vector<double>> w_store;

    std::ofstream ofs("./training_errs.dat", std::ios::app);

/*
    std::vector<double> gi;
    std::vector<double> gi1;
    std::vector<double> hi;
    calulate_Ederiv( ds_tr, gi );

    for(auto i = 0; i < gi.size(); ++i)
    {
        gi[i] = -gi[i];
    }

    gi1 = gi;
    hi = gi;

    for(int epoch = 0; epoch < max_epochs; ++epoch)
    {
        // calculate the negative gradient at point i ...
        calulate_Ederiv( ds_tr, gi );

        for(auto i = 0; i < gi.size(); ++i)
        {
            gi[i] = -gi[i];
        }

        // compute gamma ...
        double num = 0.0;
        double denom = 0.0;

        for(auto i = 0; i < gi.size(); ++i)
        {
            num += (gi[i] - gi1[i])*gi[i];
            denom += gi1[i]*gi1[i];
        }

        double gamma = num/denom;

        std::cout << "gamma: " << gamma << '\n';
        std::cout << "num: " << num << '\n';
        std::cout << "denom: " << denom << '\n';

        // compute an updated search direction h at point i ...
        for(auto i = 0; i < hi.size(); ++i)
        {
            hi[i] = gi[i] + gamma*hi[i];
        }

        // update weights ...
        for(auto i = 0; i < w.size(); ++i)
        {
            w[i] += hi[i];
        }

        assign_weights(w);

        gi1 = gi;

        // ***
        double Emean;
        double Estddev;
        std::vector<std::vector<double>> NN_out;

        std::tie( Emean, Estddev, NN_out) = parse_TS(ds_tr);
        Etr.push_back(Emean);
        Etr_stddev.push_back(Estddev);

        std::tie( Emean, Estddev, NN_out) = parse_TS(ds_val);
        Ev.push_back(Emean);
        Ev_stddev.push_back(Estddev);

        ofs << (epoch+1) << "   " << Etr.back() << "   " << Etr_stddev.back() << "   " << Ev.back() << "   " << Ev_stddev.back() << std::endl;

        w_store.push_back(w);
        // ****
    }
*/

    std::vector<double> g;
    calulate_Ederiv( ds_tr, g );

    std::vector<double> p(g.size());
    for(auto i = 0; i < g.size(); ++i)
    {
        p[i] = -g[i];
    }
    std::vector<double> g1 = g;

    for(int epoch = 0; epoch < max_epochs; ++epoch)
    {
        // calculate the negative gradient at point i ...
        calulate_Ederiv( ds_tr, g );

        // compute beta ...
        double num = 0.0;
        double denom = 0.0;

        for(auto i = 0; i < g.size(); ++i)
        {
            num += (g[i] - g1[i])*g[i]; // Polak-Ribiere
//            num += g[i]*g[i]; // Fletcher-Reeves
            denom += g1[i]*g1[i];
        }

        double beta = num/denom;

        // compute an updated search direction h at point i ...
        for(auto i = 0; i < p.size(); ++i)
        {
            p[i] = -g[i] + beta*p[i];
        }

        // update weights ...
        for(auto i = 0; i < w.size(); ++i)
        {
            w[i] += 0.00001*p[i];
        }

        assign_weights(w);

        g1 = g;

        // ***
        double Emean;
        double Estddev;
        std::vector<std::vector<double>> NN_out;

        std::tie( Emean, Estddev, NN_out) = parse_TS(ds_tr);
        Etr.push_back(Emean);
        Etr_stddev.push_back(Estddev);

        std::tie( Emean, Estddev, NN_out) = parse_TS(ds_val);
        Ev.push_back(Emean);
        Ev_stddev.push_back(Estddev);

        ofs << (epoch+1) << "   " << Etr.back() << "   " << Etr_stddev.back() << "   " << Ev.back() << "   " << Ev_stddev.back() << std::endl;

        w_store.push_back(w);
        // ****
    }

    ofs.close();

    if( early_stopping )
    {
        // NOTE: See NOTEs in ::backpropagation().

        std::vector<double>::size_type indx;
        std::vector<double>            f;
        std::tie(
                 f,
                 indx
                 ) = fit_validation_error(
                                          Ev,
                                          Ev_stddev
                                          );

        std::ofstream ofs2("fit_valid_err.dat");

        ofs2 << "using point: " << indx << '\n';

        for( auto i = 0; i < f.size(); ++i )
        {
            ofs2 << (i+1) << "   " << f[i] << '\n';
        }

        ofs2.close();

        assign_weights(w_store[indx]);
    }
    else
    {
        assign_weights(w_store.back());
    }
}
